[Speculative Decoding] Medusa Implementation with Top-1 proposer #4978
Conversation
@abhigoyal1997 Awesome PR. Do you have any plans regarding support for tree-style speculation and verification?
Hi @caddfa31434
Hi @abhigoyal1997 We recently developed a TreeMask-based version of Medusa on our internal inference framework, and it is expected to be ready by the end of the month. Here are some suggestions. There are three main differences between the TreeMask version and the non-TreeMask version (three modification points):
Additionally, EAGLE is a technology worth considering. Therefore, I suggest you introduce some abstractions when implementing Medusa, which will make future integration with EAGLE easier.
@zhyncs Thanks for the insights into the tree-mask-based version.
Yep. And in our scenario, EAGLE performs better than Medusa.
@cadedaniel This is complete. Can you please review it?
Thanks for the contribution! Will take a look today or tomorrow.
Indeed, there are several challenges in tree-style speculative decoding, including but not limited to what is mentioned here. Especially when large batch sizes are considered, different requests may have different acceptance paths in the candidate tree, and processing them efficiently will be an issue. I'm focusing on solving these right now, and should be able to upstream my solution to vLLM soon. See also #4669 (comment)
+1; I suggest we generalize top-1 and top-k proposing/scoring (including defragmentation of accepted KV). Then we can use the top-1 and top-k implementations with different spec proposal methods (draft, medusa, eagle, ngram, etc.). Also, we can wait for masking in kernels; alternatively, we can implement it in batch-expansion style. It won't be as performant, but it could be a faster way to get everything built, since we can add the kernel support when it's ready.
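For illustration, here is a minimal sketch of what scoring tree candidates in batch-expansion style could look like: every root-to-leaf path of the candidate tree is flattened into its own sequence, so a plain causal-attention forward pass can score it without a tree mask. The tree encoding and helper below are illustrative assumptions, not vLLM's API.

from typing import List

def expand_tree_to_paths(prefix: List[int],
                         children: List[List[int]],
                         tokens: List[int]) -> List[List[int]]:
    # children[i] lists the child node indices of node i (node 0 is a virtual
    # root holding the committed prefix); tokens[i] is the draft token at node i.
    # Returns one prefix+path token list per leaf, ready to be scored as an
    # ordinary batch of sequences.
    paths: List[List[int]] = []

    def walk(node: int, so_far: List[int]) -> None:
        cur = so_far + ([tokens[node]] if node != 0 else [])
        if not children[node]:        # leaf: emit a full candidate sequence
            paths.append(prefix + cur)
            return
        for child in children[node]:
            walk(child, cur)

    walk(0, [])
    return paths

# Example: root -> {7, 8}, 7 -> {9}; two leaves => two expanded sequences.
print(expand_tree_to_paths([101, 102], [[1, 2], [3], [], []], [-1, 7, 8, 9]))
# [[101, 102, 7, 9], [101, 102, 8]]

The cost of this expansion is that shared prefixes are recomputed once per path, which is why it is less performant than proper tree masking in the kernel.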
The performance will be particularly poor. I don't recommend doing this.
At the same time, supporting TreeMask in the attention kernel is not as much work as one might imagine. The prerequisite is to understand the implementation of the original causal mask. If you're interested, we can discuss the details further.
From our internal experience, the real challenge is not integrating with continuous batching, but rather compatibility with existing features such as chunked prefill.
Regarding the tree mask in the attention kernel, we can indeed combine efforts. Take a look at the discussion here: Dao-AILab/flash-attention#924. The flash-attention repo is also calling for such a contribution. The API design can refer to Hugging Face's implementation: huggingface/transformers#27539. An efficient tree attention kernel will be a crucial factor in tree-style speculative decoding, which we should consider prioritizing.
Right, generally speaking, speculative decoding (not necessarily the tree-style kind) and chunked prefill both try to utilize the spare compute of communication-bound execution, and they compete with each other for it. But there is a slight difference in their application scenarios: chunked prefill targets long prompt inputs, while speculative decoding focuses on accelerating small batch sizes. The incompatibility mentioned here is more of a trade-off between these two scenarios and how to allocate the communication-bound computation budget.
In fact, our implementation is not based on Dao-AILab/flash-attention, but on the TurboMind 2.1 attention kernel, which was written from scratch and performs about 10% better than Dao-AILab/flash-attention. The changes related to the causal mask are roughly as follows:

// original
__device__ void ApplyCasualMask(FragS& frag_S, int offset_Q, int offset_K)
{
Impl::ForeachS(frag_S, [&](int hi, int qi, int si, int ri, float& score) {
if (offset_Q + qi < offset_K + si) {
score -= std::numeric_limits<float>::infinity();
}
});
}

// modified
__device__ void ApplyCasualMask(
FragS& frag_S, int offset_Q, int offset_K, const int* medusa_mask, int his_len, int input_len, int query_idx)
{
Impl::ForeachS(frag_S, [&](int hi, int qi, int si, int ri, float& score) {
if (medusa_mask) {
int rel_pos_q = offset_Q + qi - his_len;
int rel_pos_k = offset_K + si - his_len;
if (0 <= rel_pos_q && rel_pos_q < input_len && 0 <= rel_pos_k && rel_pos_k < input_len) {
if (medusa_mask[rel_pos_q * input_len + rel_pos_k] == 0) {
score -= std::numeric_limits<float>::infinity();
}
}
else {
if (offset_Q + qi < offset_K + si) {
score -= std::numeric_limits<float>::infinity();
}
}
}
else {
if (offset_Q + qi < offset_K + si) {
score -= std::numeric_limits<float>::infinity();
}
}
});
}
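To make the kernel snippet above concrete, here is a minimal host-side sketch (an assumption about the setup, not TurboMind's actual code) of how the flattened medusa_mask it indexes as medusa_mask[rel_pos_q * input_len + rel_pos_k] could be built: query position q may attend to key position k iff k is q itself or an ancestor of q in the candidate tree.

from typing import List

def build_medusa_mask(parents: List[int]) -> List[int]:
    # parents[i] is the parent node index of tree node i, or -1 for the root.
    # Returns a flattened n*n mask where 1 means "may attend" and 0 means masked.
    n = len(parents)
    mask = [0] * (n * n)
    for q in range(n):
        k = q
        while k != -1:                # walk up to the root, marking ancestors
            mask[q * n + k] = 1
            k = parents[k]
    return mask

# Example: a chain root(0) -> 1 -> 2 plus a sibling node 3 under the root.
parents = [-1, 0, 1, 0]
n = len(parents)
m = build_medusa_mask(parents)
for q in range(n):
    print(m[q * n:(q + 1) * n])
# [1, 0, 0, 0]
# [1, 1, 0, 0]
# [1, 1, 1, 0]
# [1, 0, 0, 1]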
This depends on the framework in question, for example whether it already has a sufficiently good abstract design and is convenient to extend, so it's hard to generalize. By comparison, vLLM is indeed friendlier for secondary development and also has relatively strong extensibility.
If you're interested in combining chunked prefill and spec decode, see #5016. We have a naive dynamic speculation length policy which disables spec decode when the batch size gets too large.
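Illustratively, such a policy can be as simple as the sketch below (the names and threshold are made up for illustration, not the actual code in #5016):

def proposal_len_for_step(batch_size: int,
                          k: int = 5,
                          max_speculative_batch_size: int = 32) -> int:
    # Disable speculation entirely once the batch is large enough that the
    # GPU is already well utilized; otherwise propose k tokens per sequence.
    return 0 if batch_size > max_speculative_batch_size else k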
Hi @cadedaniel |
@cadedaniel FlashInfer supports custom masks now: flashinfer-ai/flashinfer#266
@abhigoyal1997 Could you resolve the conflicts first? Thanks.
@zhyncs Thanks for pointing that out. I've resolved the conflicts 👍
@abhigoyal1997 It seems that there are some issues with TP > 1.
@caddfa31434 Thanks for testing and catching this. The problem was with the order of execution and hidden_states broadcasting in non-driver workers. The latest commit should fix these issues (I have tested for TP = 2). |
Thanks for the contribution -- glad to finally see Medusa working in open-source vLLM. Adding high-level feedback. Some other questions:
- Can we add an e2e test with Medusa? We should expect greedy generation with Medusa (temp=0) to equal the non-spec-decode output. You can follow test_spec_decode_e2e_greedy_correctness_tiny_model_bs1 as an example (a rough sketch follows after this list).
- We'll also want to cover cases like CUDA graph / TP>1 / other non-greedy sampling params.
- The biggest concern I have with this PR is the modification of prepare inputs for Medusa-specific models. It seems this PR is introducing two new things to prepare inputs -- allow on-GPU inputs, and also model-specific input config. Can we separate out these changes into their own PR to make things simpler? Additionally, can you walk me through the alternatives to model-specific input config in prepare inputs?
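A rough sketch of the greedy-equality check suggested in the first point above (not the actual vLLM test; the model name and Medusa checkpoint path are placeholders):

from vllm import LLM, SamplingParams

PROMPTS = ["The capital of France is", "Speculative decoding works by"]
PARAMS = SamplingParams(temperature=0.0, max_tokens=64)

def greedy_texts(**engine_kwargs):
    # The real tests recreate and tear down the engine between runs; this is
    # only a sketch of the comparison being asked for.
    llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct", **engine_kwargs)
    return [out.outputs[0].text for out in llm.generate(PROMPTS, PARAMS)]

def test_medusa_greedy_matches_baseline():
    baseline = greedy_texts()
    with_medusa = greedy_texts(
        speculative_model="path/to/medusa-heads",  # hypothetical checkpoint
        num_speculative_tokens=5,
    )
    # With temperature=0, Medusa speculation should not change the tokens.
    assert baseline == with_medusa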
vllm/config.py
Outdated
@@ -137,6 +139,13 @@ def __init__(
sliding_window_len=self.get_hf_config_sliding_window())
self.served_model_name = get_served_model_name(model,
served_model_name)

self.extra_inputs: Dict[str, Tuple[Tuple[int],
Can we list out the schema of what's allowed here for Medusa?
vllm/config.py
Outdated
@@ -321,6 +331,10 @@ def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
total_num_hidden_layers = self.hf_text_config.num_hidden_layers
return total_num_hidden_layers // parallel_config.pipeline_parallel_size

def set_num_lookahead_tokens(self, num_lookahead_tokens: int):
Can you describe the relationship between num_lookahead_tokens and the number of Medusa heads?
For context, the num_lookahead_tokens value is used to allocate KV space for speculative tokens. Since Medusa does not use KV, we shouldn't require this to be equal to the number of heads.
logits = torch.stack(logits, dim=0).float()
logprobs = torch.log_softmax(logits, dim=-1)
token_ids = logits.argmax(-1)  # support only top-1 for now
probs = torch.softmax(logits, dim=-1)
If we use the lossless rejection sampler, we will have to run vLLM's standard sampling routine here -- the probability distribution must be modified in the same way as the scoring probability distributions, else you will get distributional drift in the output.
Can you please elaborate on the distribution shift? The tokens from the draft model are either accepted or rejected based on target model distribution, right? So even if the tokens from the draft are from a slightly different distribution, the final output should still match the target model distribution due to rejection. Is this understanding wrong or am I missing something?
The issue with using the standard sampling is that it was causing too much overhead. So if we do need to use it, we might need some optimizations there to get some speed-up out of Medusa.
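For reference, a minimal standalone sketch of the lossless rejection-sampling rule under discussion (an illustration, not vLLM's sampler): a draft token x drawn from q is accepted with probability min(1, p(x)/q(x)); on rejection, a replacement is drawn from the residual max(p - q, 0), renormalized. The output matches the target distribution p only if q really is the distribution the draft token was produced from, after the same temperature/top-p transforms applied on the target side, which is why the standard sampling routine comes up here.

import torch

def rejection_sample(p: torch.Tensor, q: torch.Tensor, draft_token: int) -> int:
    # p, q: target and draft probability vectors over the vocabulary.
    accept_prob = torch.clamp(p[draft_token] / q[draft_token], max=1.0)
    if torch.rand(()) < accept_prob:
        return draft_token                       # keep the draft token
    residual = torch.clamp(p - q, min=0.0)       # resample from the leftover mass
    residual = residual / residual.sum()
    return int(torch.multinomial(residual, 1))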
There's one case I have noticed that sometimes generates different tokens (not sure if this is what you are referring to, though).
If, without Medusa, the logits of the top-2 tokens have very close (or identical) values, then with Medusa those values sometimes change slightly (I don't know why this happens, since Medusa shouldn't affect the output logits of the target model). This causes the target model to prefer different tokens, even with greedy sampling, depending on how those values change.
I realised this was happening because of bf16 precision; I'm not seeing any such shift when using fp32.
@LiuXiaoxuanPKU Can you please take a look?
Sorry for the late review! LGTM! But I have some minor questions & comments:
- It seems only greedy sampling is supported; standard sampling is not.
- Could you add some end-to-end tests to make sure Medusa generates almost the same results as without speculative decoding? No need to be very strict. Take a look at this.
Happy to get this pr merged soon, sorry for the delay!
Thanks @LiuXiaoxuanPKU for the review!
Standard sampling works as well; I have results showing a speed-up with temperature=1 in the PR description. Since this PR only adds the top-1 proposal candidate, sampling in the Medusa model just uses argmax (we can extend this to top-k when we add tree speculation). My thought was that even when sampling with a temperature of 1, the target model is still most likely to choose the top-1 token, since it has the highest probability; so if we propose Medusa's top-1 token, it has a better chance of being accepted than any other token (the chance that tokens independently sampled from the target and the Medusa head would match would be much lower). Am I missing something in this argument? Even in the Medusa paper, the candidates are formed only from top-k tokens.
Thanks for the reference! I've added similar test for Medusa here: https://github.com/flipkart-incubator/vllm/blob/medusa/tests/spec_decode/e2e/test_medusa_correctness.py
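A tiny numeric illustration of the acceptance argument above, for a toy three-token distribution shared by the target and the Medusa head (the numbers are made up):

p = [0.6, 0.3, 0.1]

# Proposing the head's top-1 token: it matches whenever the target samples it.
match_top1 = max(p)                        # 0.6

# Sampling the proposal from the same distribution: both draws must coincide.
match_sampled = sum(x * x for x in p)      # 0.46

print(match_top1, match_sampled)

Since the sum of p_i^2 is at most max_i p_i for any probability vector, proposing the top-1 token never lowers this match probability.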
Hi, I assume that you got Mistral-7B-Instruct-v0.2 + Medusa from https://huggingface.co/text-generation-inference/Mistral-7B-Instruct-v0.2-medusa, and that for Meta-Llama-3-8B-Instruct + Medusa you trained Medusa-1 yourself using the codebase from https://github.com/FasterDecoding/Medusa/tree/main. How many Medusa heads did you use for training?
Hi, it seems that the speedups are much smaller than those of MLPSpeculator: #4947. Do you have any insights on when we should use Medusa instead of MLPSpeculator?
Hi @hustxiayang
I am not sure whether there are specific guidelines on when one should be used over the other. As for performance, it depends on the dataset used, the training method, etc., and the results I posted in the description are for a baseline model I trained, which can be improved with more work. Tree speculation and verification using MQA should improve the performance as well.
Hi, it seems that this implementation is based on Medusa version 1, which loads a separate lm_head for each Medusa head. Medusa version 2 proposes reusing the lm_head from the base model. Have you already investigated this, and do you plan to implement it? Thanks!
This PR implements the Medusa approach to generate speculations using top-1 predictions of the heads.
For Mistral-7B-Instruct-v0.2 and Meta-Llama-3-8B-Instruct on an H100 card, the following are the throughput numbers (total tokens generated/sec) when tested on MT-Bench:
So for smaller batch sizes, we see an improvement in tokens generated per second.
Medusa heads for these models were trained using a set of public instructions. I am working on making those available via Huggingface Hub as well.
With tree-style speculation and verification, this should give even higher improvements.
FIX #1023
FIX #4669
Fix FasterDecoding/Medusa#41